{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Customize a new quantization algorithm\n\nTo write a new quantization algorithm, you can write a class that inherits ``nni.compression.pytorch.Quantizer``.\nThen, override the member functions with the logic of your algorithm. The member function to override is ``quantize_weight``.\n``quantize_weight`` directly returns the quantized weights rather than mask, because for quantization the quantized weights cannot be obtained by applying mask.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from nni.compression.pytorch import Quantizer\n\nclass YourQuantizer(Quantizer):\n    def __init__(self, model, config_list):\n        \"\"\"\n        Suggest you to use the NNI defined spec for config\n        \"\"\"\n        super().__init__(model, config_list)\n\n    def quantize_weight(self, weight, config, **kwargs):\n        \"\"\"\n        quantize should overload this method to quantize weight tensors.\n        This method is effectively hooked to :meth:`forward` of the model.\n\n        Parameters\n        ----------\n        weight : Tensor\n            weight that needs to be quantized\n        config : dict\n            the configuration for weight quantization\n        \"\"\"\n\n        # Put your code to generate `new_weight` here\n        new_weight = ...\n        return new_weight\n\n    def quantize_output(self, output, config, **kwargs):\n        \"\"\"\n        quantize should overload this method to quantize output.\n        This method is effectively hooked to `:meth:`forward` of the model.\n\n        Parameters\n        ----------\n        output : Tensor\n            output that needs to be quantized\n        config : dict\n            the configuration for output quantization\n        \"\"\"\n\n        # Put your code to generate `new_output` here\n        new_output = ...\n        return new_output\n\n    def quantize_input(self, *inputs, config, **kwargs):\n        \"\"\"\n        quantize should overload this method to quantize input.\n        This method is effectively hooked to :meth:`forward` of the model.\n\n        Parameters\n        ----------\n        inputs : Tensor\n            inputs that needs to be quantized\n        config : dict\n            the configuration for inputs quantization\n        \"\"\"\n\n        # Put your code to generate `new_input` here\n        new_input = ...\n        return new_input\n\n    def update_epoch(self, epoch_num):\n        pass\n\n    def step(self):\n        \"\"\"\n        Can do some processing based on the model or weights binded\n        in the func bind_model\n        \"\"\"\n        pass"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Customize backward function\n\nSometimes it's necessary for a quantization operation to have a customized backward function,\nsuch as `Straight-Through Estimator <https://stackoverflow.com/questions/38361314/the-concept-of-straight-through-estimator-ste>`__\\ ,\nuser can customize a backward function as follow:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType\n\nclass ClipGrad(QuantGrad):\n    @staticmethod\n    def quant_backward(tensor, grad_output, quant_type):\n        \"\"\"\n        This method should be overrided by subclass to provide customized backward function,\n        default implementation is Straight-Through Estimator\n        Parameters\n        ----------\n        tensor : Tensor\n            input of quantization operation\n        grad_output : Tensor\n            gradient of the output of quantization operation\n        quant_type : QuantType\n            the type of quantization, it can be `QuantType.INPUT`, `QuantType.WEIGHT`, `QuantType.OUTPUT`,\n            you can define different behavior for different types.\n        Returns\n        -------\n        tensor\n            gradient of the input of quantization operation\n        \"\"\"\n\n        # for quant_output function, set grad to zero if the absolute value of tensor is larger than 1\n        if quant_type == QuantType.OUTPUT:\n            grad_output[tensor.abs() > 1] = 0\n        return grad_output\n\nclass _YourQuantizer(Quantizer):\n    def __init__(self, model, config_list):\n        super().__init__(model, config_list)\n        # set your customized backward function to overwrite default backward function\n        self.quant_grad = ClipGrad"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you do not customize ``QuantGrad``, the default backward is Straight-Through Estimator. \n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}